{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-03-01T06:48:09.450393Z",
     "start_time": "2020-03-01T06:48:09.123812Z"
    }
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "import os\n",
    "import zipfile\n",
    "\n",
    "root_path = './../datasets'\n",
    "processed_folder =  os.path.join(root_path)\n",
    "\n",
    "zip_ref = zipfile.ZipFile(os.path.join(root_path,'omniglot_standard.zip'), 'r')\n",
    "zip_ref.extractall(root_path)\n",
    "zip_ref.close()\n",
    "\n",
    "root_dir = './../datasets/omniglot/python'\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 数据预处理\n",
    "\n",
    "import torchvision.transforms as transforms\n",
    "from PIL import Image\n",
    "\n",
    "'''\n",
    "an example of img_items:\n",
    "( '0709_17.png',\n",
    "  'Alphabet_of_the_Magi/character01',\n",
    "  './../datasets/omniglot/python/images_background/Alphabet_of_the_Magi/character01')\n",
    "'''\n",
    "def find_classes(root_dir):\n",
    "    img_items = []\n",
    "    for (root, dirs, files) in os.walk(root_dir): \n",
    "        for file in files:\n",
    "            if (file.endswith(\"png\")):\n",
    "                r = root.split('/')\n",
    "                img_items.append((file, r[-2] + \"/\" + r[-1], root))\n",
    "    print(\"== Found %d items \" % len(img_items))\n",
    "    return img_items\n",
    "\n",
    "## 构建一个词典{class:idx}\n",
    "def index_classes(items):\n",
    "    class_idx = {}\n",
    "    count = 0\n",
    "    for item in items:\n",
    "        if item[1] not in class_idx:\n",
    "            class_idx[item[1]] = count\n",
    "            count += 1\n",
    "    print('== Found {} classes'.format(len(class_idx)))\n",
    "    return class_idx\n",
    "        \n",
    "\n",
    "img_items =  find_classes(root_dir)\n",
    "class_idx = index_classes(img_items)\n",
    "\n",
    "\n",
    "temp = dict()\n",
    "for imgname, classes, dirs in img_items:\n",
    "    img = '{}/{}'.format(dirs, imgname)\n",
    "    label = class_idx[classes]\n",
    "    transform = transforms.Compose([lambda img: Image.open(img).convert('L'),\n",
    "                              lambda img: img.resize((28,28)),\n",
    "                              lambda img: np.reshape(img, (28,28,1)),\n",
    "                              lambda img: np.transpose(img, [2,0,1]),\n",
    "                              lambda img: img/255.\n",
    "                              ])\n",
    "    img = transform(img)\n",
    "    if label in temp.keys():\n",
    "        temp[label].append(img)\n",
    "    else:\n",
    "        temp[label] = [img]\n",
    "print('begin to generate omniglot.npy')\n",
    "## 移除标签信息，每个标签包含20个样本\n",
    "img_list = []\n",
    "for label, imgs in temp.items():\n",
    "    img_list.append(np.array(imgs))\n",
    "img_list = np.array(img_list).astype(np.float) # [[20 imgs],..., 1623 classes in total]\n",
    "print('data shape:{}'.format(img_list.shape)) # (1623, 20, 1, 28, 28)\n",
    "temp = []\n",
    "np.save(os.path.join(root_dir, 'omniglot.npy'), img_list)\n",
    "print('end.')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "wangkai3.6",
   "language": "python",
   "name": "wangkai3.6"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
